import numpy as np
import torch
import torch.optim as optim
from sklearn.neighbors import kneighbors_graph
from typing import List
import os
import re
import math
import matplotlib.pyplot as plt

def fmt_sigma(s):
    """Return ``str(s)`` with every '.' replaced by 'p' (filename-safe tag, e.g. 0.5 -> '0p5')."""
    pieces = str(s).split(".")
    return "p".join(pieces)

def generate_circle_given_X(n_samples, sigma, X, radius = 3):
    """Sample responses on a noisy circle centred at (X, X).

    Each row is ``(X + radius*sin(2U) + e1, X + radius*cos(2U) + e2)`` with
    ``U ~ Uniform(0, 2*pi)`` and ``(e1, e2) ~ N(0, sigma^2 I)``, drawn from the
    numpy global RNG (U first, then the noise).

    Args:
        n_samples: number of rows to draw.
        sigma: standard deviation of the additive Gaussian noise.
        X: covariate; a scalar or an array broadcastable against (n_samples, 1).
        radius: circle radius.

    Returns:
        Array whose two column blocks are the y1 and y2 coordinates
        (shape (n_samples, 2) for scalar or (n_samples, 1) ``X``).
    """
    angle = np.random.uniform(0, 2 * np.pi, size=(n_samples, 1))
    noise = np.random.normal(0, sigma, size=(n_samples, 2))

    doubled = 2 * angle
    first = X + radius * np.sin(doubled) + noise[:, 0:1]
    second = X + radius * np.cos(doubled) + noise[:, 1:2]
    return np.concatenate([first, second], axis=1)

def generate_circle_data(n_samples, sigma, radius = 3):
    """Draw (X, Y) pairs with X ~ N(0, 1) and Y on a noisy circle centred at (X, X).

    Random draws use the numpy global RNG in the order: X, angle U, noise.

    Args:
        n_samples: number of pairs.
        sigma: standard deviation of the additive Gaussian noise on Y.
        radius: circle radius.

    Returns:
        X: array of shape (n_samples, 1).
        Y: array of shape (n_samples, 2) with columns
           ``X + radius*sin(2U) + e1`` and ``X + radius*cos(2U) + e2``,
           ``U ~ Uniform(0, 2*pi)``, ``(e1, e2) ~ N(0, sigma^2 I)``.
    """
    covariate = np.random.randn(n_samples, 1)
    phase = 2 * np.random.uniform(0, 2 * np.pi, size=(n_samples, 1))
    noise = np.random.normal(0, sigma, size=(n_samples, 2))

    # Column 0 uses sin, column 1 uses cos; the (n, 1) covariate broadcasts over both.
    unit_circle = np.column_stack((np.sin(phase), np.cos(phase)))
    response = covariate + radius * unit_circle + noise
    return covariate, response

def train_gcds(X_train, Y_train, eta_train, n_samples, Generator, Discriminator,
               seed=42, batch_size=200, epochs=2000, lr=1e-3, print_interval=50):
    """Adversarially train a GCDS generator/discriminator pair with Adam.

    Each mini-batch performs one discriminator step that *maximises*

        L_D = mean(D(x, G(x, eta))) - mean(exp(D(x, y)))

    followed by one generator step that *minimises* mean(D(x, G(x, eta))).
    The noise is not resampled: each sample keeps its own row of
    ``eta_train`` for the whole run.

    Args:
        X_train, Y_train, eta_train: row-indexable tensors; every epoch draws
            batches from a fresh random permutation of ``range(n_samples)``
            (presumably ``n_samples == len(X_train)`` -- confirm at call sites).
        n_samples: number of rows iterated over per epoch.
        Generator, Discriminator: zero-argument constructors of modules called
            as ``G(x, eta)`` and ``D(x, y)``.
        seed: torch seed set before the networks are constructed.
        batch_size: mini-batch size (the last batch may be smaller).
        epochs: number of passes over the data.
        lr: Adam learning rate shared by both optimizers.
        print_interval: the last batch's losses are printed every this many
            epochs and after the final epoch.

    Returns:
        (G, D): the trained generator and discriminator.
    """
    torch.manual_seed(seed)

    gen = Generator()
    disc = Discriminator()
    gen_opt = optim.Adam(gen.parameters(), lr=lr)
    disc_opt = optim.Adam(disc.parameters(), lr=lr)

    final_epoch = epochs - 1
    for epoch in range(epochs):
        order = torch.randperm(n_samples)

        for start in range(0, n_samples, batch_size):
            rows = order[start:start + batch_size]
            x = X_train[rows]
            y = Y_train[rows]
            noise = eta_train[rows]

            # One generator pass per batch, shared by both updates below.
            y_gen = gen(x, noise)

            # Discriminator: ascend L_D by descending on -L_D. The generated
            # sample is detached so this step does not backpropagate into G.
            disc_opt.zero_grad()
            score_real = disc(x, y)
            score_fake = disc(x, y_gen.detach())
            loss_D = score_fake.mean() - torch.exp(score_real).mean()
            loss_D.neg().backward()
            disc_opt.step()

            # Generator: descend on the score of the just-updated discriminator.
            gen_opt.zero_grad()
            loss_G = disc(x, y_gen).mean()
            loss_G.backward()
            gen_opt.step()

        should_log = epoch % print_interval == 0 or epoch == final_epoch
        if should_log:
            print(f"Epoch {epoch}/{epochs}: Loss_D={loss_D.item():.4f}, Loss_G={loss_G.item():.4f}")

    return gen, disc

def gaussian_kernel(A: torch.Tensor, B: torch.Tensor, sigma: float = 1) -> torch.Tensor:
    """Return the pairwise Gaussian kernel matrix ``exp(-sigma * ||A_i - B_j||^2)``.

    ``sigma`` multiplies the squared Euclidean distance, so it acts as an
    inverse bandwidth (larger ``sigma`` -> narrower kernel), not as a standard
    deviation.

    Args:
        A: tensor of shape (n, d).
        B: tensor of shape (m, d).
        sigma: kernel scale; a Python number or a tensor. A tensor is used
            as-is, so gradients can flow into a learnable scale.

    Returns:
        Tensor of shape (n, m) whose entry (i, j) is
        ``exp(-sigma * ||A[i] - B[j]||^2)``.
    """
    if not torch.is_tensor(sigma):
        # Keep a plain Python float: it adopts the distances' dtype. Casting to
        # a float32 tensor would round sigma for float64 inputs (and, for a
        # tensor sigma, copy-construct a detached tensor with a warning).
        sigma = float(sigma)
    diff = A.unsqueeze(1) - B.unsqueeze(0)  # (n, m, d)
    # Sum of squares directly instead of sqrt (norm) followed by squaring.
    sq_dist = diff.pow(2).sum(dim=2)  # (n, m)
    return torch.exp(-sigma * sq_dist)

def ECMMD(Z, Y, X, batch, kernel, neighbors: int, kernel_band = 1) -> torch.Tensor:
    """Nearest-neighbour estimate of the expected conditional MMD between two samples.

    Builds a k-nearest-neighbour graph on the conditioning covariates ``X``
    (sklearn ``kneighbors_graph`` on a detached CPU copy, self excluded) and
    averages

        H[i, j] = k(Z_i, Z_j) + k(Y_i, Y_j) - k(Z_i, Y_j) - k(Y_i, Z_j)

    over every directed pair (i, j) in which j is one of the k nearest
    neighbours of i.

    Args:
        Z, Y: paired samples with the same number of rows as ``X``, in whatever
            form ``kernel`` accepts.
        X: conditioning covariates of shape (n, p). Only used to build the
            graph, so no gradient flows through it.
        batch: kept for backward compatibility and no longer used. The mean is
            taken over the neighbour pairs actually summed (n * k), which stays
            correct when a final mini-batch is shorter than the nominal size.
        kernel: callable ``kernel(A, B, sigma=...)`` returning a
            (len(A), len(B)) matrix, e.g. ``gaussian_kernel``.
        neighbors: number of neighbours k; clamped to n - 1 so that small
            batches do not make sklearn raise.
        kernel_band: forwarded to ``kernel`` as ``sigma``.

    Returns:
        0-dim tensor (differentiable w.r.t. Z and Y whenever ``kernel`` is).

    Raises:
        ValueError: if ``X`` has fewer than two rows (no neighbours exist).
    """
    n = X.shape[0]
    if n < 2:
        raise ValueError(f"ECMMD needs at least 2 samples to build a neighbour graph, got {n}.")
    k = min(int(neighbors), n - 1)

    # Row i of the COO graph lists the k nearest neighbours j of sample i.
    graph = kneighbors_graph(X.detach().cpu().numpy(), k, include_self=False).tocoo()

    # Kernel matrices (kept as four calls: `kernel` need not be symmetric).
    kernel_ZZ = kernel(Z, Z, sigma=kernel_band)
    kernel_YY = kernel(Y, Y, sigma=kernel_band)
    kernel_ZY = kernel(Z, Y, sigma=kernel_band)
    kernel_YZ = kernel(Y, Z, sigma=kernel_band)
    H = kernel_ZZ + kernel_YY - kernel_ZY - kernel_YZ

    rows = torch.as_tensor(graph.row, dtype=torch.long, device=H.device)
    cols = torch.as_tensor(graph.col, dtype=torch.long, device=H.device)
    # Average over the neighbour pairs that were actually summed.
    return H[rows, cols].sum() / rows.numel()

def train_ecmmd_model(X_train, Y_train, eta_train, n_samples, Generator,
                      seed=42, lr=1e-3, epochs=1000, batch_size=200, neighbors=15,
                      kernel_function=None, kernel_band=1, print_interval=50,
                      checkpoint=False, checkpoint_dir=None, num_checkpoints=5,
                      ckpt_name_fmt="ECMMD_Generator_{tag}_ep{epoch:04d}.pth",tag=None,
                      ckpt_max_epoch=125):
    """Train a conditional generator by minimising the ECMMD loss with Adam.

    Every epoch draws mini-batches from a fresh random permutation of
    ``range(n_samples)``. For each batch it generates ``Y_fake = model(X, eta)``
    with the sample's fixed row of ``eta_train`` and takes one Adam step on
    ``ECMMD(Y_batch, Y_fake, X_batch, ...)``.

    Args:
        X_train, Y_train, eta_train: row-indexable tensors (presumably
            ``n_samples == len(X_train)`` -- confirm at call sites).
        n_samples: number of rows iterated over per epoch.
        Generator: zero-argument constructor of a module called as
            ``model(x, eta)``.
        seed: torch seed set before the model is constructed.
        lr: Adam learning rate.
        epochs: number of passes over the data.
        batch_size: mini-batch size (the last batch may be smaller).
        neighbors: k for the ECMMD nearest-neighbour graph.
        kernel_function: kernel passed to ECMMD; defaults to ``gaussian_kernel``
            when None.
        kernel_band: forwarded to the kernel as ``sigma``.
        print_interval: the last batch's loss is printed every this many epochs
            and after the final epoch.
        checkpoint: whether to save ``model.state_dict()`` during training.
        checkpoint_dir: directory for checkpoints (created if missing). If None,
            nothing is saved and a warning is issued when ``checkpoint`` is True.
        num_checkpoints: number of checkpoint epochs. They are evenly spaced
            (truncated to int, duplicates merged) over
            ``[1, min(ckpt_max_epoch, epochs)]``.
        ckpt_name_fmt: filename template with ``{tag}`` and ``{epoch}`` fields.
        tag: value for ``{tag}``; "run" when falsy.
        ckpt_max_epoch: last epoch of the checkpoint schedule. The default of
            125 keeps the previous schedule and is clamped to ``epochs``, so
            short runs still receive all their checkpoints. Pass None to spread
            the checkpoints over the whole run.

    Returns:
        (model, saved_ckpt_paths): the trained module and the paths written,
        in epoch order.
    """
    import warnings

    torch.manual_seed(seed)
    model = Generator()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    if kernel_function is None:
        # ECMMD calls the kernel unconditionally; None would raise a TypeError.
        kernel_function = gaussian_kernel

    # --- checkpoint schedule (1-based epoch numbers) ---
    saved_ckpt_paths = []
    ckpt_epochs = set()
    if checkpoint and num_checkpoints > 0 and epochs > 0:
        last = epochs if ckpt_max_epoch is None else min(int(ckpt_max_epoch), epochs)
        last = max(1, last)
        ckpt_epochs = {
            int(e) for e in np.linspace(1, last, num=num_checkpoints, endpoint=True, dtype=int)
        }
        if checkpoint_dir is None:
            warnings.warn(
                "checkpoint=True but checkpoint_dir is None; no checkpoints will be saved.",
                stacklevel=2,
            )
        else:
            os.makedirs(checkpoint_dir, exist_ok=True)

    loss = None  # stays None if no batch is ever processed (e.g. n_samples == 0)
    for epoch in range(epochs):
        permutation = torch.randperm(n_samples)

        for i in range(0, n_samples, batch_size):
            indices = permutation[i:i + batch_size]
            X_batch = X_train[indices]
            Y_batch = Y_train[indices]
            eta_batch = eta_train[indices]

            Y_fake = model(X_batch, eta_batch)

            # Pass the actual batch length: the final batch can be shorter than batch_size.
            loss = ECMMD(Y_batch, Y_fake, X_batch, X_batch.shape[0], kernel_function,
                         neighbors, kernel_band=kernel_band)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is not None and (epoch % print_interval == 0 or epoch == epochs - 1):
            print(f"Epoch {epoch}/{epochs}: Loss={loss.item():.4f}")

        # Checkpoint saving
        ep1 = epoch + 1
        if ep1 in ckpt_epochs and checkpoint_dir is not None:
            fname = ckpt_name_fmt.format(tag=(tag or "run"), epoch=ep1)
            fpath = os.path.join(checkpoint_dir, fname)
            torch.save(model.state_dict(), fpath)
            saved_ckpt_paths.append(fpath)
            print(f"[ECMMD] checkpoint saved: {fpath}")

    return model, saved_ckpt_paths

def plot_generated_and_test_samples(
    seed,
    G1,
    G2,
    noise_dim,
    sigma,
    generate_circle_given_X,
    x_eval=1,
    sample_size=1000,
    s=1,
    alpha=0.5,
    save_dir_fig="generated_data",
    save_dir_csv="figures",
    panel_basename="circle_comparison",
    save_csv=True,
    return_fig=False,
):
    """Plot GCDS samples, ECMMD samples and reference samples, one row per sigma.

    Each row has three scatter panels:
    - samples from ``G1`` ("Generated by GCDS");
    - samples from ``G2`` ("Generated by ECMMD");
    - samples from ``generate_circle_given_X(sample_size, sig, x_eval)``.
    Both generators are conditioned on the constant covariate ``x_eval`` and
    receive fresh standard-normal noise of width ``noise_dim``.

    Args:
        seed: base seed; row ``r`` reseeds torch and numpy with ``seed + r``.
        G1, G2: an ``nn.Module`` (used for every row), a dict mapping sigma to
            a model, or a factory callable ``sigma -> model``. Resolved models
            are switched to eval mode and called as ``model(X, eta)``.
        noise_dim: width of the noise tensor ``eta``.
        sigma: a scalar or an iterable of noise levels (one figure row each).
        generate_circle_given_X: sampler for the reference panel.
        x_eval: covariate value the samples are conditioned on.
        sample_size: points per panel.
        s, alpha: scatter marker size and transparency.
        save_dir_fig: directory for the PDF (created if missing).
            NOTE(review): the default looks swapped with ``save_dir_csv``;
            kept for compatibility.
        save_dir_csv: directory for per-sigma CSV dumps (created if missing).
        panel_basename: PDF filename prefix.
        save_csv: whether to write the sampled points as CSV files.
        return_fig: return the figure instead of None.

    Returns:
        The matplotlib figure if ``return_fig`` is True, otherwise None.

    Raises:
        ValueError: if ``sigma`` is an empty iterable.
    """
    os.makedirs(save_dir_fig, exist_ok=True)
    os.makedirs(save_dir_csv, exist_ok=True)

    sigmas = [float(sigma)] if np.isscalar(sigma) else [float(x) for x in sigma]
    if not sigmas:
        raise ValueError("sigma must contain at least one value.")
    nrows = len(sigmas)

    # --------- Figure and label styling parameters ----------
    FIG_WIDTH = 14
    ROW_HEIGHT = 3.2
    ROW_LABEL_FONTSIZE = 20
    ROW_LABEL_OFFSET = 0.40
    ROW_GAP_HSPACE = 0.95
    BOTTOM_MARGIN = 0.12
    # ---------------------------------------------------------------------

    def _get_model(M, sig):
        """Resolve a model spec for noise level ``sig``."""
        # nn.Module instances are themselves callable, so they must be matched
        # before the factory branch; otherwise M(sig) would run the network's
        # forward pass on a float.
        if isinstance(M, torch.nn.Module):
            return M
        if isinstance(M, dict):
            return M[sig]
        if callable(M):
            return M(sig)
        return M

    def _dev(model):
        """Device of the model's first parameter (CPU for parameter-free models)."""
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _sample(model):
        """Draw ``sample_size`` outputs of ``model`` at X = x_eval as a numpy array."""
        dev = _dev(model)
        X_eval_vec = torch.full((sample_size, 1), fill_value=x_eval, dtype=torch.float32, device=dev)
        eta_eval = torch.randn(sample_size, noise_dim, device=dev)
        with torch.no_grad():
            return model(X_eval_vec, eta_eval).detach().cpu().numpy()

    def _scatter_panel(ax, Y, title):
        """Scatter the two columns of ``Y`` with the shared panel styling."""
        ax.scatter(Y[:, 0], Y[:, 1], s=s, alpha=alpha)
        ax.set_title(title, fontsize=20)
        ax.set_xlabel(r"$y_1$", fontsize=22)
        ax.set_ylabel(r"$y_2$", fontsize=22)
        ax.grid(True)

    # squeeze=False keeps `axes` 2-D with shape (nrows, 3) even for a single row.
    fig, axes = plt.subplots(nrows=nrows, ncols=3, figsize=(FIG_WIDTH, ROW_HEIGHT * nrows), squeeze=False)

    for r, sig in enumerate(sigmas):
        torch.manual_seed(seed + r)
        np.random.seed(seed + r)

        g1 = _get_model(G1, sig)
        g2 = _get_model(G2, sig)
        g1.eval()
        g2.eval()

        # Draw order (G1 noise, G2 noise, then numpy test data) fixes the RNG streams per row.
        Y_gen1 = _sample(g1)
        Y_gen2 = _sample(g2)
        Y_test = generate_circle_given_X(sample_size, sig, x_eval)

        if save_csv:
            tag = fmt_sigma(sig)
            np.savetxt(os.path.join(save_dir_csv, f"gcds_generated_samples_sigma{tag}.csv"), Y_gen1, delimiter=",")
            np.savetxt(os.path.join(save_dir_csv, f"ecmmd_generated_samples_sigma{tag}.csv"), Y_gen2, delimiter=",")
            np.savetxt(os.path.join(save_dir_csv, f"test_samples_sigma{tag}.csv"), Y_test, delimiter=",")

        ax0, ax1, ax2 = axes[r]
        _scatter_panel(ax0, Y_gen1, "Generated by GCDS")
        _scatter_panel(ax1, Y_gen2, "Generated by ECMMD")
        _scatter_panel(ax2, Y_test, "Test Samples")

        label_prefix = f"({chr(97 + r)}) "  # (a), (b), (c), ...
        ax1.annotate(
            rf"{label_prefix}$\sigma = {sig}$",
            xy=(0.5, -ROW_LABEL_OFFSET), xycoords="axes fraction",
            ha="center", va="top",
            fontsize=ROW_LABEL_FONTSIZE
        )

    fig.tight_layout()
    fig.subplots_adjust(hspace=ROW_GAP_HSPACE, bottom=BOTTOM_MARGIN)

    if nrows == 1:
        out_pdf = os.path.join(save_dir_fig, f"{panel_basename}_sigma={sigmas[0]}.pdf")
    else:
        tag_all = "_".join(fmt_sigma(sg) for sg in sigmas)
        out_pdf = os.path.join(save_dir_fig, f"{panel_basename}_sigmas_{tag_all}.pdf")
    fig.savefig(out_pdf, format="pdf", bbox_inches="tight")

    plt.show()

    if return_fig:
        return fig


def plot_ecmmd_checkpoints_by_sigma(
    *,
    seed: int,
    sigmas,
    checkpoint_map,
    GeneratorCtor,
    noise_dim: int,
    generate_circle_given_X,
    x_eval: float = 1.0,
    sample_size: int = 1000,
    device: torch.device | str = "cpu",
    # figure & IO
    save_dir_fig: str = "generated_data",
    save_dir_csv: str = "figures",
    panel_basename: str = "ecmmd_checkpoints",
    add_test_column: bool = True,
    save_csv: bool = True,
    return_fig: bool = False,
    s: float = 2.0,
    alpha: float = 0.6,
    fig_width: float = 12.0,
    row_height: float = 3.0,
    row_label_fontsize: int = 16,
    row_gap_hspace: float = 1.0,
    bottom_margin: float = 0.12,
):
    """Plot ECMMD generator samples at several training checkpoints, one row per sigma.

    For each sigma, every checkpoint in ``checkpoint_map[sigma]`` is loaded into
    a fresh ``GeneratorCtor()``. The model is evaluated at ``X = x_eval`` with
    one shared noise draw for the row and scattered in its own column, ordered
    by epoch. If ``add_test_column`` is True, the last column shows samples from
    ``generate_circle_given_X(sample_size, sigma, x_eval)``.

    Args:
        seed: base seed; row ``r`` reseeds torch and numpy with ``seed + r``.
        sigmas: noise levels (converted to float; also used as
            ``checkpoint_map`` keys).
        checkpoint_map: ``{sigma: [path, ...] | [(epoch, path), ...]}``. For
            bare paths the epoch is parsed from the filename (last
            ``ep``/``epoch`` number), falling back to the 1-based list position.
        GeneratorCtor: zero-argument constructor of the generator module.
        noise_dim: width of the noise tensor passed to the generator.
        generate_circle_given_X: sampler for the test column.
        x_eval: covariate value the samples are conditioned on.
        sample_size: points per panel.
        device: device used for the models and inputs.
        save_dir_fig: directory for the PDF (created if missing).
        save_dir_csv: directory for CSV dumps (created if missing).
        panel_basename: PDF filename prefix.
        add_test_column: append a ground-truth column.
        save_csv: write each checkpoint's samples (and test samples) as CSV.
        return_fig: return the figure instead of None.
        s, alpha: scatter marker size and transparency.
        fig_width, row_height, row_label_fontsize, row_gap_hspace,
        bottom_margin: figure layout parameters.

    Returns:
        The matplotlib figure if ``return_fig`` is True, otherwise None.

    Raises:
        ValueError: if ``sigmas`` is empty, or there is nothing to plot.
        KeyError: if a sigma is missing from ``checkpoint_map``.
    """
    os.makedirs(save_dir_fig, exist_ok=True)
    os.makedirs(save_dir_csv, exist_ok=True)

    def _fmt_sigma(sig: float) -> str:
        """Filename tag for a sigma, e.g. 0.5 -> '0p5', 1.0 -> '1'."""
        return f"{sig:.3f}".rstrip("0").rstrip(".").replace(".", "p")

    _epoch_pat = re.compile(r"ep(?:och)?[_-]?(\d{1,6})", re.IGNORECASE)

    def _normalize_ckpts(ckpt_list):
        """Return ``[(epoch, path), ...]`` sorted by epoch."""
        out = []
        for idx, item in enumerate(ckpt_list):
            if isinstance(item, (tuple, list)) and len(item) == 2:
                out.append((int(item[0]), item[1]))
                continue
            path = str(item)
            # Take the LAST match: the default name puts the epoch at the end,
            # and an earlier match may come from the run tag (e.g. "deep1").
            hits = _epoch_pat.findall(os.path.basename(path))
            ep = int(hits[-1]) if hits else idx + 1
            out.append((ep, path))
        out.sort(key=lambda t: t[0])
        return out

    def _load_model(path: str):
        """Build a generator on ``device``, load weights from ``path``, set eval mode."""
        model = GeneratorCtor().to(device)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(path, map_location=device)
        # Accept either a bare state_dict or a wrapper dict {"state_dict": ...}.
        if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
            state = state["state_dict"]
        model.load_state_dict(state)
        model.eval()
        return model

    def _generate(path, X_eval_vec, eta_eval):
        """Run the checkpoint at ``path`` on the given inputs; return a numpy array."""
        model = _load_model(path)
        with torch.no_grad():
            return model(X_eval_vec, eta_eval).detach().cpu().numpy()

    def _decorate(ax, col):
        """Shared axis labels/grid; the y label only on the first column."""
        ax.set_xlabel(r"$y_1$", fontsize=12)
        if col == 0:
            ax.set_ylabel(r"$y_2$", fontsize=12)
        ax.grid(True, alpha=0.25)

    sigmas = [float(sg) for sg in sigmas]
    if not sigmas:
        raise ValueError("sigmas must contain at least one value.")
    ckpts_by_sigma = {sg: _normalize_ckpts(checkpoint_map[sg]) for sg in sigmas}

    ncols = max(len(v) for v in ckpts_by_sigma.values()) + (1 if add_test_column else 0)
    if ncols == 0:
        raise ValueError("Nothing to plot: no checkpoints given and add_test_column is False.")
    nrows = len(sigmas)

    # squeeze=False keeps `axes` 2-D (nrows, ncols) even when nrows or ncols is 1.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(fig_width, row_height * nrows), squeeze=False)

    for r, sig in enumerate(sigmas):
        torch.manual_seed(seed + r)
        np.random.seed(seed + r)

        # One shared input/noise draw per row, so columns differ only by checkpoint.
        X_eval_vec = torch.full((sample_size, 1), fill_value=x_eval, dtype=torch.float32, device=device)
        eta_eval = torch.randn(sample_size, noise_dim, device=device)
        Y_test = generate_circle_given_X(sample_size, sig, x_eval) if add_test_column else None

        ckpts = ckpts_by_sigma[sig]
        generated = []  # (epoch, samples); reused for the CSV dump below
        for c, (ep, path) in enumerate(ckpts):
            Y_gen = _generate(path, X_eval_vec, eta_eval)
            generated.append((ep, Y_gen))
            ax = axes[r, c]
            ax.scatter(Y_gen[:, 0], Y_gen[:, 1], s=s, alpha=alpha)
            ax.set_title(f"ECMMD @ epoch {ep}", fontsize=11)
            _decorate(ax, c)

        # Remaining columns: ground truth goes in the last column only; the
        # rest (rows with fewer checkpoints than the widest row) stay blank.
        for c in range(len(ckpts), ncols):
            ax = axes[r, c]
            if Y_test is not None and c == ncols - 1:
                ax.scatter(Y_test[:, 0], Y_test[:, 1], s=s, alpha=alpha)
                ax.set_title("Test (ground truth)", fontsize=11)
                _decorate(ax, c)
            else:
                ax.axis("off")

        midc = min(len(ckpts) - 1, ncols // 2)
        row_tag = f"({chr(97 + r)})  " + r"$\sigma = $" + f"{sig}"
        axes[r, midc].annotate(
            row_tag, xy=(0.5, -0.40), xycoords="axes fraction",
            ha="center", va="top", fontsize=row_label_fontsize
        )

        if save_csv:
            sig_tag = _fmt_sigma(sig)
            for ep, Y_gen in generated:
                out_csv = os.path.join(save_dir_csv, f"ecmmd_gen_sig{sig_tag}_ep{ep:04d}.csv")
                np.savetxt(out_csv, Y_gen, delimiter=",")
            if Y_test is not None:
                np.savetxt(os.path.join(save_dir_csv, f"test_sig{sig_tag}.csv"), Y_test, delimiter=",")

    fig.tight_layout()
    fig.subplots_adjust(hspace=row_gap_hspace, bottom=bottom_margin)

    tag_all = "_".join(_fmt_sigma(sg) for sg in sigmas)
    out_pdf = os.path.join(save_dir_fig, f"{panel_basename}_sigmas_{tag_all}.pdf")
    fig.savefig(out_pdf, format="pdf", bbox_inches="tight")
    plt.show()

    if return_fig:
        return fig

